""" DiffGro Skill Termination Predictor Implementation """
from functools import partial
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import time
import gym
import jax
import jax.numpy as jnp
import numpy as np
import haiku as hk
import wandb

from sb3_jax.common.offline_algorithm import OfflineAlgorithm
from sb3_jax.common.buffers import BaseBuffer
from sb3_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule
from sb3_jax.common.jax_utils import jax_print, jit_optimize

from diffgro.utils.utils import print_b
from diffgro.common.buffers import PredictorBuffer
from diffgro.diffgro.policies import DiffGroPredictorPolicy


class DiffGroPredictor(OfflineAlgorithm):
    """Offline trainer for the DiffGro skill-termination predictor.

    Trains ``self.prd`` (an instance of ``policy_class``) as a binary classifier
    of the ``dones`` field of batches sampled from ``replay_buffer``, using a
    binary cross-entropy loss, and logs mean loss / accuracy per iteration.

    :param policy: Predictor policy class (or its registered name).
    :param env: Environment (or its id) that provides the observation and action
        spaces. Only ``gym.spaces.Box`` action spaces are accepted.
    :param replay_buffer: Buffer class sampled during training.
    :param learning_rate: Passed to the policy constructor and logged to wandb.
    :param batch_size: Mini-batch size for every gradient step.
    :param gamma: Forwarded to the base class and logged; the loss does not use it.
    :param gradient_steps: Gradient steps performed per ``learn`` iteration.
    :param wandb_log: Weights & Biases settings, or ``None`` to disable wandb
        logging. When given it must contain a ``'tag'`` key (read in ``learn``).
    :param policy_kwargs: Extra keyword arguments for the policy constructor.
    :param verbose: Verbosity level forwarded to the base class.
    :param seed: Random seed for the algorithm and the policy.
    :param _init_setup_model: Whether to build the policy at construction time.
    """

    def __init__(
        self,
        policy: Union[str, Type[DiffGroPredictorPolicy]],
        env: Union[GymEnv, str],
        replay_buffer: Type[BaseBuffer] = PredictorBuffer,
        learning_rate: float = 3e-4,
        batch_size: int = 256,
        gamma: float = 0.99,
        gradient_steps: int = 1,
        wandb_log: Optional[Dict[str, Any]] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = 1,
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            replay_buffer=replay_buffer,
            learning_rate=learning_rate,
            batch_size=batch_size,
            gamma=gamma,
            gradient_steps=gradient_steps,
            tensorboard_log=None,
            wandb_log=wandb_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            create_eval_env=False,
            seed=seed,
            # The model is built below (after the base class is initialised),
            # so the base class must not build it itself.
            _init_setup_model=False,
            # Trailing comma matters: ``(gym.spaces.Box)`` is just the class,
            # not a tuple of supported spaces.
            supported_action_spaces=(gym.spaces.Box,),
            support_multi_env=False,
        )
        # Keep the raw float: it is passed to the policy constructor in
        # ``_setup_model`` and logged to wandb in ``learn``.
        self.learning_rate = learning_rate

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Seed the RNGs, instantiate the predictor policy and create aliases."""
        self.set_random_seed(self.seed)
        self.policy = self.policy_class(
            self.observation_space,
            self.action_space,
            self.learning_rate,
            seed=self.seed,
            **self.policy_kwargs,
        )
        self._create_aliases()

    def _create_aliases(self) -> None:
        """Expose the policy under the short name ``prd`` used by training code."""
        self.prd = self.policy

    def train(self, gradient_steps: int, batch_size: int = 256) -> None:
        """Run ``gradient_steps`` predictor updates and log their mean loss/accuracy.

        Each step samples ``batch_size`` items from the replay buffer, preprocesses
        the observations, and applies one ``jit_optimize`` update to
        ``self.prd.params`` / ``self.prd.optim_state``.

        :param gradient_steps: Number of optimizer updates to perform.
        :param batch_size: Number of samples drawn per update.
        """
        losses, accs = [], []
        for _ in range(gradient_steps):
            self._n_updates += 1
            replay_data = self.replay_buffer.sample(batch_size)

            # Start observations are preprocessed with training=False, the
            # observation sequence with training=True (as in the original code).
            p_obs_0 = self.policy.preprocess(replay_data.start_observations, training=False)
            # Flatten to rows of ``obs_dim`` for preprocessing, then regroup so each
            # sample becomes one flat row. NOTE(review): presumably observations are
            # a per-sample window of shape (batch, T, obs_dim) — confirm.
            obs = replay_data.observations
            p_obs = self.policy.preprocess(obs.reshape(-1, self.policy.obs_dim), training=True)
            p_obs = p_obs.reshape(batch_size, -1)
            act = replay_data.actions.reshape(batch_size, -1)

            self.prd.optim_state, self.prd.params, loss, info = jit_optimize(
                self._loss,
                self.prd.optim,
                self.prd.optim_state,
                self.prd.params,
                max_grad_norm=None,
                obs_0=p_obs_0,
                obs=p_obs,
                act=act,
                skill=replay_data.skills,
                done=replay_data.dones,
            )
            losses.append(loss)
            accs.append(info["acc"])

        mean_loss = np.mean(losses)
        mean_acc = np.mean(accs)
        self.logger.record("train/pd/loss", mean_loss)
        self.logger.record("train/pd/acc", mean_acc)
        if self.wandb_log is not None:
            wandb.log({
                "time/total_timesteps": self.num_timesteps,
                "train/pd/loss": mean_loss,
                "train/pd/acc": mean_acc,
            })

    def _loss(
        self,
        pd_params: hk.Params,
        obs_0: jax.Array,
        obs: jax.Array,
        act: jax.Array,
        skill: jax.Array,
        done: jax.Array,
        rng=None,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Binary cross-entropy of the predicted termination probability vs ``done``.

        This function is traced/differentiated by ``jit_optimize``, so it must use
        ``jax.numpy`` throughout (no plain NumPy reductions).

        :param pd_params: Predictor parameters being optimised.
        :param obs_0: Preprocessed start observations.
        :param obs: Preprocessed observations, one row per sample.
        :param act: Actions, one row per sample.
        :param skill: Skill inputs passed through to the predictor.
        :param done: Binary targets. NOTE(review): assumed to have the same shape as
            the predictor output so the products below do not broadcast — confirm.
        :param rng: Optional PRNG key forwarded to the predictor.
        :return: ``(loss, {"acc": accuracy})``.
        """
        batch_size = obs.shape[0]

        # inference
        batch_dict = {"obs_0": obs_0, "obs": obs, "act": act, "skill": skill}
        pred = self.prd._pd(batch_dict, pd_params, rng)
        # Keep probabilities away from 0/1 so the logs below stay finite.
        pred = jnp.clip(pred, 1e-7, 1 - 1e-7)

        # bce loss (jnp, not np: this runs under JAX tracing / autodiff)
        loss = done * jnp.log(pred) + (1 - done) * jnp.log(1 - pred)
        loss = -jnp.mean(loss)

        # accuracy with a 0.5 threshold; both comparisons are strict, so a
        # prediction of exactly 0.5 counts as wrong for either label.
        acc = done * (pred > 0.5) + (1 - done) * (pred < 0.5)
        acc = jnp.sum(acc) / batch_size
        return loss, {"acc": acc}

    def learn(
        self,
        total_timesteps: Tuple[int, int],
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "DiffGroPredictor",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "DiffGroPredictor":
        """Train the predictor for ``total_timesteps`` iterations.

        Each iteration calls ``train`` once (i.e. ``self.gradient_steps`` updates),
        increments the timestep counters and, every ``log_interval`` iterations,
        dumps timing statistics. Training stops early if the callback's
        ``on_step`` returns ``False``; ``on_training_end`` is still called.

        NOTE(review): ``total_timesteps`` is annotated as a tuple, but the value
        returned by ``_setup_learn`` is compared with an int counter below —
        confirm the intended type.

        :return: ``self``, to allow call chaining.
        """
        self.log_interval = log_interval

        # wandb configs
        if self.wandb_log is not None:
            self.wandb_config = dict(
                time=time.ctime(),
                algo='diffgro/predictor',
                tag=self.wandb_log['tag'],
                learning_rate=self.learning_rate,
                batch_size=self.batch_size,
                gamma=self.gamma,
                gradient_steps=self.gradient_steps,
                seed=self.seed,
            )
            self.wandb_config.update(self.policy._get_constructor_parameters())

        total_timesteps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
        )
        callback.on_training_start(locals(), globals())

        # learn predictor module
        start_time = time.time()
        num_timesteps = 0
        while num_timesteps < total_timesteps:
            self.train(gradient_steps=self.gradient_steps, batch_size=self.batch_size)

            self.num_timesteps += 1
            num_timesteps += 1
            if log_interval is not None and num_timesteps % log_interval == 0:
                # Read the clock once and guard against a zero elapsed time.
                time_elapsed = max(time.time() - start_time, 1e-8)
                self.logger.record("time/fps", int(num_timesteps / time_elapsed))
                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                self.logger.record("time/total_timesteps", num_timesteps, exclude="tensorboard")
                self.logger.dump(step=num_timesteps)

            callback.update_locals(locals())
            if callback.on_step() is False:
                # Stop early, but still notify the callback and return ``self``
                # as the signature promises (previously returned ``False``).
                break

        callback.on_training_end()
        return self

    def load_params(self, path: str) -> None:
        """Load predictor parameters and the normalization layer saved at ``path``."""
        # ``load_from_zip_file`` was never imported at module level, which made this
        # method raise NameError. NOTE(review): module path assumed to mirror
        # stable-baselines3's ``common.save_util`` — confirm for sb3_jax.
        from sb3_jax.common.save_util import load_from_zip_file

        print_b("[diffgro/predictor] : loading params")
        _, params = load_from_zip_file(path, verbose=1)
        self._load_jax_params(params)
        self._load_norm_layer(path)

    def _save_jax_params(self) -> Dict[str, hk.Params]:
        """Return the parameters to serialise, keyed as ``{"pd_params": ...}``."""
        return {"pd_params": self.prd.params}

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """Delegate loading of serialised parameters to the predictor policy."""
        self.prd._load_jax_params(params)

    def _save_norm_layer(self, path: str) -> None:
        """Save the policy's normalization layer to ``path`` if one is configured."""
        if self.policy.normalization_class is not None:
            self.policy.normalization_layer.save(path)

    def _load_norm_layer(self, path: str) -> None:
        """Replace the policy's normalization layer with the one loaded from ``path``."""
        if self.policy.normalization_class is not None:
            self.policy.normalization_layer = self.policy.normalization_layer.load(path)
